using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Fantasy.SourceGenerator.Generators
{
    /// <summary>
    /// Incremental source generator that finds message handler classes (classes whose direct base
    /// type is one of the Fantasy <c>Message</c>/<c>MessageRPC</c>/<c>Address</c>/<c>AddressRPC</c>/
    /// <c>Addressable</c>/<c>AddressableRPC</c>/<c>Roaming</c>/<c>RoamingRPC</c> handler bases) and
    /// emits one <c>global::Fantasy.Assembly.IMessageHandlerResolver</c> implementation per
    /// assembly that exposes the handlers' OpCodes and <c>Handle</c> delegates.
    /// </summary>
    [Generator]
    public sealed class MessageHandlerGenerator : IIncrementalGenerator
    {
        /// <summary>Delegate type of plain (session-level) message handlers in the generated code.</summary>
        private const string MessageHandlerDelegateType =
            "global::System.Func<global::Fantasy.Network.Session, uint, object, global::Fantasy.Async.FTask>";

        /// <summary>Delegate type of entity-bound (address/addressable/roaming) handlers in the generated code.</summary>
        private const string AddressMessageHandlerDelegateType =
            "global::System.Func<global::Fantasy.Network.Session, global::Fantasy.Entitas.Entity, uint, object, global::Fantasy.Async.FTask>";

        /// <summary>
        /// Substrings of a base-type name that mark a class as a handler candidate. This is a cheap
        /// syntactic pre-filter; the semantic check in <see cref="GetMessageTypeInfo"/> is authoritative.
        /// </summary>
        private static readonly string[] HandlerBaseTypeMarkers =
        {
            "IMessageHandler",
            "IAddressMessageHandler",
            "Message<",
            "MessageRPC<",
            "Address<",
            "AddressRPC<",
            "Addressable<",
            "AddressableRPC<",
            "Roaming<",
            "RoamingRPC<",
        };

        /// <summary>Names of the generated opcode holder classes searched in the message's namespace.</summary>
        private static readonly string[] OpCodeTypeNames = { "OuterOpcode", "InnerOpcode" };

#pragma warning disable RS2008 // Analyzer release tracking is not configured for this generator.
        /// <summary>
        /// Reported when a handler's message OpCode cannot be resolved at compile time. Such a handler
        /// is left out of the generated registration so that the generated file still compiles.
        /// </summary>
        private static readonly DiagnosticDescriptor UnresolvedOpCodeDescriptor = new DiagnosticDescriptor(
            id: "FANTASYMSG001",
            title: "Message handler OpCode could not be resolved",
            messageFormat: "The OpCode of the message handled by '{0}' could not be resolved at compile time; the handler was not registered",
            category: "Fantasy.SourceGenerator",
            defaultSeverity: DiagnosticSeverity.Warning,
            isEnabledByDefault: true);
#pragma warning restore RS2008

        /// <summary>
        /// Wires up the incremental pipeline: handler discovery, per-compilation gating, and source output.
        /// </summary>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Find all classes that derive from one of the message handler base classes.
            var messageHandlers = context.SyntaxProvider
                .CreateSyntaxProvider(
                    predicate: static (node, _) => IsMessageHandlerClass(node),
                    transform: static (ctx, cancellationToken) => GetMessageTypeInfo(ctx, cancellationToken))
                .Where(static info => info is not null)
                .Select(static (info, _) => info!)
                .Collect();

            // Reduce the compilation to the few values the output step needs. Combining with the raw
            // CompilationProvider would re-run the output step on every edit, because a new
            // Compilation instance never compares equal to the previous one.
            var generationTarget = context.CompilationProvider
                .Select(static (compilation, _) => GetGenerationTarget(compilation));

            // Register the source output.
            context.RegisterSourceOutput(generationTarget.Combine(messageHandlers), static (spc, source) =>
            {
                // A null target means generation is disabled or this is not a Fantasy project.
                if (source.Left is not { } target)
                {
                    return;
                }

                GenerateRegistrationCode(spc, target, source.Right);
            });
        }

        /// <summary>
        /// Decides whether code should be generated for <paramref name="compilation"/> and, if so,
        /// captures the names the generated file needs.
        /// </summary>
        /// <returns>
        /// <c>null</c> when the generator is disabled, the Fantasy define is absent, or
        /// <c>Fantasy.Assembly.INetworkProtocolRegistrar</c> cannot be found; otherwise the target names.
        /// </returns>
        private static GenerationTarget? GetGenerationTarget(Compilation compilation)
        {
            if (CompilationHelper.IsSourceGeneratorDisabled(compilation) ||
                !CompilationHelper.HasFantasyDefine(compilation) ||
                compilation.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") == null)
            {
                return null;
            }

            var markerClassName = compilation.GetAssemblyName("MessageHandlerResolverRegistrar", out var assemblyName, out _);
            return new GenerationTarget(markerClassName, assemblyName);
        }

        /// <summary>
        /// Splits the discovered handlers into plain and route (entity-bound) handlers and emits the
        /// resolver class into <c>{MarkerClassName}.g.cs</c>.
        /// </summary>
        /// <param name="context">Output context used to add the source and report diagnostics.</param>
        /// <param name="target">Names of the generated class and of the assembly it is generated for.</param>
        /// <param name="messageHandlerInfos">All handlers discovered in the compilation.</param>
        private static void GenerateRegistrationCode(
            SourceProductionContext context,
            GenerationTarget target,
            IEnumerable<MessageHandlerInfo> messageHandlerInfos)
        {
            var messageHandlers = new List<MessageHandlerInfo>();
            var routeMessageHandlers = new List<MessageHandlerInfo>();

            foreach (var messageHandlerInfo in messageHandlerInfos)
            {
                if (messageHandlerInfo.OpCode == null)
                {
                    // An unresolved OpCode would be emitted as "uintArray[i] = ;" and break the
                    // generated file. Skip the handler in both arrays (keeping them aligned) and
                    // tell the user why it is missing.
                    context.ReportDiagnostic(Diagnostic.Create(
                        UnresolvedOpCodeDescriptor,
                        Location.None,
                        messageHandlerInfo.TypeFullName));
                    continue;
                }

                switch (messageHandlerInfo.HandlerType)
                {
                    case HandlerType.MessageHandler:
                    {
                        messageHandlers.Add(messageHandlerInfo);
                        break;
                    }
                    case HandlerType.RouteMessageHandler:
                    {
                        routeMessageHandlers.Add(messageHandlerInfo);
                        break;
                    }
                }
            }

            var builder = new SourceCodeBuilder();
            builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
            builder.AddUsings(
                "System",
                "System.Collections.Generic",
                "Fantasy.Assembly",
                "Fantasy.DataStructure.Dictionary",
                "Fantasy.Network.Interface",
                "Fantasy.Network",
                "Fantasy.Entitas",
                "Fantasy.Async",
                "System.Runtime.CompilerServices"
            );
            builder.AppendLine();
            builder.BeginDefaultNamespace();
            builder.AddXmlComment($"Auto-generated message handler registration class for {target.AssemblyName}");
            builder.BeginClass(target.MarkerClassName, "internal sealed", "global::Fantasy.Assembly.IMessageHandlerResolver");
            // Emit the IMessageHandlerResolver members.
            GenerateCode(builder, messageHandlers, routeMessageHandlers);
            // Close the class and the namespace.
            builder.EndClass();
            builder.EndNamespace();
            // Output the source code.
            context.AddSource($"{target.MarkerClassName}.g.cs", builder.ToString());
        }

        /// <summary>
        /// Emits the four resolver methods. The opcode array and the delegate array produced for the
        /// same handler list share indices, so index <c>i</c> of one corresponds to index <c>i</c> of the other.
        /// </summary>
        private static void GenerateCode(SourceCodeBuilder builder, List<MessageHandlerInfo> messageHandlers, List<MessageHandlerInfo> routeMessageHandlers)
        {
            AppendOpCodesMethod(builder, "MessageHandlerOpCodes", messageHandlers);
            AppendHandlersMethod(builder, "MessageHandlers", MessageHandlerDelegateType, messageHandlers);
            AppendOpCodesMethod(builder, "AddressMessageHandlerOpCodes", routeMessageHandlers);
            // The AddressMessageHandler method is only compiled when FANTASY_NET is defined.
            builder.AppendLine("#if FANTASY_NET", false);
            AppendHandlersMethod(builder, "AddressMessageHandler", AddressMessageHandlerDelegateType, routeMessageHandlers);
            builder.AppendLine("#endif", false);
        }

        /// <summary>
        /// Emits <c>public uint[] {methodName}()</c> returning the OpCodes of <paramref name="handlers"/> in order.
        /// Callers must pass only handlers whose <see cref="MessageHandlerInfo.OpCode"/> is non-null.
        /// </summary>
        private static void AppendOpCodesMethod(SourceCodeBuilder builder, string methodName, List<MessageHandlerInfo> handlers)
        {
            builder.AddXmlComment(methodName);
            builder.BeginMethod($"public uint[] {methodName}()");

            if (handlers.Count == 0)
            {
                builder.AppendLine("return Array.Empty<uint>();");
            }
            else
            {
                builder.AppendLine($"var uintArray = new uint[{handlers.Count}];");
                for (var i = 0; i < handlers.Count; i++)
                {
                    builder.AppendLine($"uintArray[{i}] = {handlers[i].OpCode};");
                }

                builder.AppendLine("return uintArray;");
            }

            builder.EndMethod();
        }

        /// <summary>
        /// Emits <c>public {delegateType}[] {methodName}()</c> returning a new instance's <c>Handle</c>
        /// method group for each entry of <paramref name="handlers"/>, in order.
        /// </summary>
        private static void AppendHandlersMethod(SourceCodeBuilder builder, string methodName, string delegateType, List<MessageHandlerInfo> handlers)
        {
            builder.AddXmlComment(methodName);
            builder.BeginMethod($"public {delegateType}[] {methodName}()");

            if (handlers.Count == 0)
            {
                builder.AppendLine($"return Array.Empty<{delegateType}>();");
            }
            else
            {
                builder.AppendLine($"var handlerArray = new {delegateType}[{handlers.Count}];");
                for (var i = 0; i < handlers.Count; i++)
                {
                    builder.AppendLine($"handlerArray[{i}] = new {handlers[i].TypeFullName}().Handle;");
                }

                builder.AppendLine("return handlerArray;");
            }

            builder.EndMethod();
        }

        /// <summary>
        /// Syntactic pre-filter: true for a class declaration with at least one base-list entry whose
        /// text contains one of <see cref="HandlerBaseTypeMarkers"/>.
        /// </summary>
        private static bool IsMessageHandlerClass(SyntaxNode node)
        {
            if (node is not ClassDeclarationSyntax { BaseList: { } baseList })
            {
                return false;
            }

            foreach (var baseType in baseList.Types)
            {
                var typeName = baseType.Type.ToString();

                foreach (var marker in HandlerBaseTypeMarkers)
                {
                    if (typeName.Contains(marker))
                    {
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Semantic check for a candidate class. Returns handler info when the class is instantiable
        /// and its direct base type is one of the known generic handler bases; otherwise null.
        /// </summary>
        private static MessageHandlerInfo? GetMessageTypeInfo(GeneratorSyntaxContext context, System.Threading.CancellationToken cancellationToken)
        {
            var classDecl = (ClassDeclarationSyntax)context.Node;

            if (context.SemanticModel.GetDeclaredSymbol(classDecl, cancellationToken) is not INamedTypeSymbol symbol ||
                !symbol.IsInstantiable())
            {
                return null;
            }

            if (symbol.BaseType is not { IsGenericType: true } baseType || baseType.TypeArguments.Length == 0)
            {
                return null;
            }

            switch (baseType.OriginalDefinition.ToDisplayString())
            {
                // The message type is the first type argument.
                case "Fantasy.Network.Interface.Message<T>":
                case "Fantasy.Network.Interface.MessageRPC<TRequest, TResponse>":
                {
                    return new MessageHandlerInfo(
                        HandlerType.MessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 0, cancellationToken));
                }
                // The first type argument is TEntity; the message type is the second.
                case "Fantasy.Network.Interface.Address<TEntity, TMessage>":
                case "Fantasy.Network.Interface.AddressRPC<TEntity, TAddressRequest, TAddressResponse>":
                case "Fantasy.Network.Interface.Addressable<TEntity, TMessage>":
                case "Fantasy.Network.Interface.AddressableRPC<TEntity, TAddressRequest, TAddressResponse>":
                case "Fantasy.Network.Interface.Roaming<TEntity, TMessage>":
                case "Fantasy.Network.Interface.RoamingRPC<TEntity, TAddressRequest, TAddressResponse>":
                {
                    return new MessageHandlerInfo(
                        HandlerType.RouteMessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 1, cancellationToken));
                }
            }

            return null;
        }

        /// <summary>
        /// Resolves the compile-time OpCode of the message type found at
        /// <paramref name="index"/> in <paramref name="baseType"/>'s type arguments.
        /// </summary>
        /// <returns>The OpCode, or null when it cannot be determined.</returns>
        private static uint? GetOpCode(GeneratorSyntaxContext context, INamedTypeSymbol baseType, int index, System.Threading.CancellationToken cancellationToken)
        {
            // The argument may be a type parameter (open generic handler) rather than a named type;
            // a hard cast here would throw and crash the generator.
            if (baseType.TypeArguments.Length <= index ||
                baseType.TypeArguments[index] is not INamedTypeSymbol messageType)
            {
                return null;
            }

            return FindOpCodeInOpCodeClasses(messageType)
                ?? FindOpCodeFromOpCodeMethod(messageType, context.SemanticModel.Compilation, cancellationToken);
        }

        /// <summary>
        /// Strategy 1: look for a <c>const uint</c> field named after the message in an
        /// <c>OuterOpcode</c>/<c>InnerOpcode</c> class declared in the message's own namespace
        /// (within the message's assembly).
        /// </summary>
        private static uint? FindOpCodeInOpCodeClasses(INamedTypeSymbol messageType)
        {
            // ContainingNamespace belongs to the message's own module, so this is equivalent to
            // searching the message's assembly for that namespace without walking the whole tree.
            var messageNamespace = messageType.ContainingNamespace;
            if (messageNamespace == null)
            {
                return null;
            }

            foreach (var opCodeTypeName in OpCodeTypeNames)
            {
                var opCodeType = messageNamespace.GetTypeMembers(opCodeTypeName).FirstOrDefault();
                if (opCodeType == null)
                {
                    continue;
                }

                var opCodeField = opCodeType.GetMembers(messageType.Name).OfType<IFieldSymbol>().FirstOrDefault();
                if (opCodeField is { IsConst: true, ConstantValue: uint constValue })
                {
                    return constValue;
                }
            }

            return null;
        }

        /// <summary>
        /// Strategy 2: evaluate the expression returned by the message's <c>OpCode()</c> method.
        /// Only works for messages whose declaration is part of the current compilation.
        /// Both block bodies (<c>{ return X; }</c>) and expression bodies (<c>=> X</c>) are supported.
        /// </summary>
        private static uint? FindOpCodeFromOpCodeMethod(INamedTypeSymbol messageType, Compilation compilation, System.Threading.CancellationToken cancellationToken)
        {
            var opCodeMethod = messageType.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
            var syntaxReference = opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault();
            if (syntaxReference?.GetSyntax(cancellationToken) is not MethodDeclarationSyntax opCodeSyntax)
            {
                return null;
            }

            var returnedExpression = opCodeSyntax.ExpressionBody?.Expression
                ?? opCodeSyntax.Body?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression;

            var syntaxTree = opCodeSyntax.SyntaxTree;
            if (returnedExpression == null || !compilation.ContainsSyntaxTree(syntaxTree))
            {
                return null;
            }

            var semanticModel = compilation.GetSemanticModel(syntaxTree);

            // Try resolving the expression to a const field.
            var symbolInfo = semanticModel.GetSymbolInfo(returnedExpression, cancellationToken);
            if (symbolInfo.Symbol is IFieldSymbol { IsConst: true, ConstantValue: uint fieldValue })
            {
                return fieldValue;
            }

            // Fall back to evaluating the expression as a compile-time constant.
            var constantValue = semanticModel.GetConstantValue(returnedExpression, cancellationToken);
            if (constantValue.HasValue && constantValue.Value is uint uintValue)
            {
                return uintValue;
            }

            return null;
        }

        /// <summary>Category of a discovered handler; selects which generated arrays it goes into.</summary>
        private enum HandlerType
        {
            None,
            MessageHandler,
            RouteMessageHandler
        }

        /// <summary>
        /// Value-equatable description of one handler class, so unchanged handlers are cached by the
        /// incremental pipeline. <see cref="OpCode"/> is null when it could not be resolved.
        /// </summary>
        private sealed record MessageHandlerInfo(
            HandlerType HandlerType,
            string TypeFullName,
            string TypeName,
            uint? OpCode);

        /// <summary>
        /// Value-equatable per-compilation data needed by the output step: the generated class name
        /// and the assembly name used in its XML comment.
        /// </summary>
        private sealed record GenerationTarget(string MarkerClassName, string AssemblyName);
    }
}